/*
 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * The Universal Permissive License (UPL), Version 1.0
 *
 * Subject to the condition set forth below, permission is hereby granted to any
 * person obtaining a copy of this software, associated documentation and/or
 * data (collectively the "Software"), free of charge and under any and all
 * copyright rights in the Software, and any and all patent rights owned or
 * freely licensable by each licensor hereunder covering either (i) the
 * unmodified Software as contributed to or provided by such licensor, or (ii)
 * the Larger Works (as defined below), to deal in both
 *
 * (a) the Software, and
 *
 * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
 * one is included with the Software each a "Larger Work" to which the Software
 * is contributed by such licensors),
 *
 * without restriction, including without limitation the rights to copy, create
 * derivative works of, display, perform, and distribute the Software and make,
 * use, sell, offer for sale, import, export, have made, and have sold the
 * Software and the Larger Work(s), and to sublicense the foregoing rights on
 * either these or other terms.
 *
 * This license is subject to the following condition:
 *
 * The above copyright notice and either this complete permission notice or at a
 * minimum a reference to the UPL must be included in all copies or substantial
 * portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package com.oracle.truffle.api.bytecode.debug;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.LongSummaryStatistics;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.bytecode.BytecodeDescriptor;
import com.oracle.truffle.api.bytecode.BytecodeNode;
import com.oracle.truffle.api.bytecode.InstructionDescriptor;
import com.oracle.truffle.api.bytecode.InstructionTracer;
import com.oracle.truffle.api.frame.Frame;

/**
 * Instruction tracer that records per-opcode execution counts and optionally aggregates them into
 * hierarchical histograms by user-defined group clauses.
 * <p>
 * The tracer is optimized for the no-filter, no-group case. In that mode the hot path avoids
 * boundary calls and can be partially evaluated. When grouping or filtering is enabled, a small LRU
 * cache avoids repeated evaluation of the grouping and filtering functions for the same
 * {@link BytecodeNode}, current thread, and compilation tier.
 * <p>
 * Thread safety: increments are performed via atomic counters and are safe under concurrent
 * execution. Histogram creation is a snapshot that can be performed concurrently with counting. For
 * interval semantics, prefer {@link #getHistogramAndReset()} over mixing {@link #getHistogram()}
 * with {@link #reset()}.
 *
 * <h3>Basic usage</h3>
 * <p>
 * The example below shows how to attach a {@link HistogramInstructionTracer} directly to a
 * generated bytecode root, execute code, then collect and print a histogram snapshot.
 *
 * <pre>
 * var root = MyRootNodeGen.BYTECODE.create(language, BytecodeConfig.DEFAULT, b -> {
 *     b.beginRoot();
 *     b.beginReturn();
 *     b.emitLoadArgument(0);
 *     b.endReturn();
 *     b.endRoot();
 * }).getNode(0);
 *
 * // Create and attach the histogram tracer to the root.
 * var tracer = HistogramInstructionTracer.newBuilder().build(MyRootNodeGen.BYTECODE);
 * root.getRootNodes().addInstructionTracer(tracer);
 *
 * // Execute your program as usual.
 * Object result = root.getCallTarget().call(42);
 *
 * // Take a consistent snapshot and reset counters for the next interval.
 * var histogram = tracer.getHistogramAndReset();
 *
 * // Inspect or print the histogram.
 * long total = histogram.getInstructionsExecuted();
 * histogram.print(System.out);
 *
 * // Detach the tracer when done.
 * root.getRootNodes().removeInstructionTracer(tracer);
 * </pre>
 *
 * @since 25.1
 */
public final class HistogramInstructionTracer implements InstructionTracer {

    private final BytecodeDescriptor<?, ?, ?> descriptor;
    private final Predicate<BytecodeNode> filterClause;
    private final GroupClause[] groupClauses;

    /*
     * Least recently used cache for avoiding repeated invocations of the bytecode node filter and
     * group clauses. This will perform quite poorly for heavily concurrent workloads. For those it
     * is better to not use any filters or groups.
     */
    private volatile LastTraceCache cache;
    private final Counters counters;
    private final AtomicLong[] rootCounters; // one less indirection on the hot path

    HistogramInstructionTracer(BytecodeDescriptor<?, ?, ?> descriptor,
                    Predicate<BytecodeNode> bytecodeFilter,
                    GroupClause[] groupClauses) {
        Objects.requireNonNull(descriptor);
        this.descriptor = descriptor;
        this.filterClause = bytecodeFilter;
        this.groupClauses = groupClauses;
        // without grouping the root Counters object is itself the leaf counter table
        this.counters = new Counters(operationCodeTableSize(descriptor), groupClauses == null);
        this.rootCounters = counters.data;
    }

    /**
     * Computes the size of the dense opcode counter table, which is the maximum operation code of
     * any instruction plus one.
     */
    private static int operationCodeTableSize(BytecodeDescriptor<?, ?, ?> descriptor) {
        int maxCode = 0;
        for (InstructionDescriptor instructionDescriptor : descriptor.getInstructionDescriptors()) {
            maxCode = Math.max(instructionDescriptor.getOperationCode(), maxCode);
        }
        // if this assertion fails, there is likely a bug or change in the descriptor encoding
        // currently we assume we can fit all descriptors densely in an opcode table.
        assert maxCode <= descriptor.getInstructionDescriptors().size() + 128 : "descriptor density too sparse";
        return maxCode + 1;
    }

    /**
     * Records the execution of an instruction. This method is not intended to be called directly,
     * but by the bytecode DSL framework as part of instruction tracing.
     *
     * @since 25.1
     */
    @Override
    public void onInstructionEnter(InstructionAccess access, BytecodeNode bytecode, int bytecodeIndex, Frame frame) {
        assert descriptor.getGeneratedClass() == bytecode.getRootNode().getClass() : "Statistics listener attached to the wrong bytecode descriptor.";
        /*
         * This is carefully crafted to make sure it partially evaluates cleanly if no filter and no
         * grouping are applied. Without filtering and grouping we can avoid boundary calls here,
         * which improves the peak performance of this tracer. This might be needed when tracing big
         * long-running applications.
         */
        boolean filter = this.filterClause != null; // pe-constant
        boolean group = this.groupClauses != null; // pe-constant
        LastTraceCache c;
        AtomicLong[] counterArray;
        if (!filter && !group) {
            // fast-path implementation
            c = null;
            counterArray = this.rootCounters; // pe-constant
        } else {
            c = this.cache;
            int compiledTier = getCompiledTier();
            if ((c == null || c.bytecodeNode != bytecode || c.compiledTier != compiledTier || c.thread != Thread.currentThread())) {
                // cache miss: re-evaluate filter and group clauses behind a boundary
                c = updateCache(bytecode, compiledTier);
                this.cache = c;
            }
            if (group) { // pe-constant
                counterArray = c.counters;
            } else {
                counterArray = this.rootCounters; // pe-constant
            }
            if (filter && !c.included) {
                return;
            }
        }
        counterArray[access.getTracedOperationCode(bytecode, bytecodeIndex)].incrementAndGet();
    }

    /**
     * Returns the current compilation tier: 0 for interpreter, 1 for compiled code with a next
     * tier, 2 for final tier compiled code.
     */
    private static int getCompiledTier() {
        if (CompilerDirectives.inInterpreter()) {
            return 0;
        } else if (CompilerDirectives.hasNextTier()) {
            return 1;
        } else {
            return 2;
        }
    }

    /**
     * Slow path taken on a cache miss: resolves the leaf counter table by walking the group
     * hierarchy and re-evaluates the filter predicate for the given bytecode node.
     */
    @TruffleBoundary
    private LastTraceCache updateCache(BytecodeNode bytecode, int compiledTier) {
        Thread thread = Thread.currentThread();
        Counters c = this.counters;
        if (groupClauses != null) {
            int length = groupClauses.length;
            for (int i = 0; i < length; i++) {
                // the last clause creates the leaf that holds the actual counter table
                c = c.getOrCreateGroup(groupClauses[i].group(bytecode, thread, compiledTier), i == length - 1);
            }
        }
        boolean included = filterClause == null || filterClause.test(bytecode);
        assert c.data != null; // must be leaf
        return new LastTraceCache(bytecode, c.data, included, compiledTier, thread);
    }

    /**
     * Returns the {@link BytecodeDescriptor} that this tracer is attached to.
     *
     * @return the descriptor used to resolve instruction metadata and opcodes
     * @since 25.1
     */
    @Override
    public BytecodeDescriptor<?, ?, ?> getExclusiveBytecodeDescriptor() {
        return descriptor;
    }

    /**
     * Creates a hierarchical histogram snapshot of the recorded counters without resetting them.
     * <p>
     * Use this to inspect cumulative counts. If you need interval semantics, prefer
     * {@link #getHistogramAndReset()}.
     *
     * @return a {@link Histogram} view of the current counters, possibly grouped
     * @since 25.1
     */
    public Histogram getHistogram() {
        return new Histogram(descriptor, counters, false);
    }

    /**
     * Creates a hierarchical histogram snapshot of the recorded counters and atomically resets the
     * underlying counters to zero.
     * <p>
     * Use this to obtain interval counts in long running applications. The reset is performed
     * atomically per-counter, increments that occur after the swap will be visible in the next
     * interval.
     *
     * @return a {@link Histogram} view of the counts since the previous reset
     * @since 25.1
     */
    public Histogram getHistogramAndReset() {
        return new Histogram(descriptor, counters, true);
    }

    /**
     * Renders the current histogram snapshot into a string using
     * {@link #printHistogram(PrintStream)}.
     *
     * @return a human readable histogram table as UTF-8 text
     * @see Histogram#dump()
     * @since 25.1
     */
    public String dumpHistogram() {
        /*
         * Encode and decode explicitly as UTF-8. Closing the PrintStream before reading the buffer
         * is required because PrintStream buffers internally; otherwise the tail of the output
         * could be missing.
         */
        try (ByteArrayOutputStream w = new ByteArrayOutputStream();
                        PrintStream p = new PrintStream(w, true, StandardCharsets.UTF_8)) {
            printHistogram(p);
            p.flush();
            return w.toString(StandardCharsets.UTF_8);
        } catch (IOException e) {
            // ByteArrayOutputStream.close() declares but never throws IOException
            throw new AssertionError(e);
        }
    }

    /**
     * Prints the current histogram snapshot to the given {@link PrintStream}.
     * <p>
     * The output is a single table with aligned numeric columns. Groups, if present, are printed as
     * indented pseudo rows followed by their leaves.
     *
     * @param out the stream to print to
     * @see Histogram#print(PrintStream)
     * @since 25.1
     */
    public void printHistogram(PrintStream out) {
        Objects.requireNonNull(out);
        getHistogram().print(out);
    }

    /**
     * Resets all counters to zero and discards previously recorded data.
     * <p>
     * Do not mix this with {@link #getHistogram()} if you need interval correctness, use
     * {@link #getHistogramAndReset()} instead.
     *
     * @since 25.1
     */
    public void reset() {
        this.counters.reset();
    }

    /**
     * Returns a builder for {@link HistogramInstructionTracer} that allows configuring a filter and
     * one or more grouping clauses.
     *
     * @return a new builder instance
     * @since 25.1
     */
    public static Builder newBuilder() {
        return new Builder();
    }

    /**
     * A grouping function that maps a {@link BytecodeNode}, the current {@link Thread}, and the
     * compilation tier to an arbitrary grouping key. Keys are used to form hierarchical histograms.
     *
     * @since 25.1
     */
    @FunctionalInterface
    public interface GroupClause {

        /**
         * Computes the grouping key for the given execution context.
         *
         * @param bytecodeNode the bytecode node executing
         * @param thread the current thread
         * @param compilationTier 0 for interpreter, 1 for compiled with a next tier, 2 for final
         *            compiled
         * @return a non null grouping key, for example a string, an enum, or a small record
         * @since 25.1
         */
        Object group(BytecodeNode bytecodeNode, Thread thread, int compilationTier);

    }

    /**
     * Builder for {@link HistogramInstructionTracer}.
     * <p>
     * The builder supports an optional filter predicate and an ordered list of group clauses. Group
     * clauses are applied in the order they are added, producing a hierarchy of groups.
     *
     * @since 25.1
     */
    public static final class Builder {

        private Predicate<BytecodeNode> filter;
        private final List<GroupClause> groups = new ArrayList<>();

        private Builder() {
        }

        /**
         * Sets an optional filter predicate. If present, only bytecode nodes that satisfy the
         * predicate contribute to counters.
         *
         * @param filterClause the predicate to apply, may be {@code null}
         * @return this builder
         *
         * @since 25.1
         */
        public Builder filter(Predicate<BytecodeNode> filterClause) {
            this.filter = filterClause;
            return this;
        }

        /**
         * Adds a grouping clause to the histogram. Multiple clauses create a hierarchy, the first
         * clause forms the top level, the last clause forms the leaves.
         *
         * @param clause the grouping clause to add, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code clause} is {@code null}
         * @since 25.1
         */
        public Builder groupBy(GroupClause clause) {
            Objects.requireNonNull(clause);
            groups.add(clause);
            return this;
        }

        /**
         * Builds a {@link HistogramInstructionTracer} for the given {@link BytecodeDescriptor}.
         *
         * @param descriptor the descriptor to attach to
         * @return a new tracer configured with the previously set filter and groups
         * @since 25.1
         */
        public HistogramInstructionTracer build(BytecodeDescriptor<?, ?, ?> descriptor) {
            return new HistogramInstructionTracer(descriptor, filter, groups.isEmpty() ? null : groups.toArray(GroupClause[]::new));
        }
    }

    /**
     * Represents a hierarchical histogram that can group instruction counts by arbitrary
     * attributes. A histogram is an immutable snapshot of the counters at the time it was created.
     *
     * @since 25.1
     */
    public static final class Histogram {

        private final BytecodeDescriptor<?, ?, ?> descriptor;
        private final Map<Object, Histogram> subGroups;
        private final Map<InstructionDescriptor, Long> data;
        private final long instructionsExecuted;

        private Histogram(BytecodeDescriptor<?, ?, ?> descriptor, Counters counters, boolean reset) {
            this.descriptor = descriptor;
            Map<InstructionDescriptor, Long> map = new LinkedHashMap<>();
            long sum = 0;
            if (counters.subgroups != null) {
                // parent level: recurse into children and aggregate their counts into this level
                Map<Object, Histogram> groups = new LinkedHashMap<>();
                for (var entry : counters.subgroups.entrySet()) {
                    Histogram histogram = new Histogram(descriptor, entry.getValue(), reset);
                    groups.put(entry.getKey(), histogram);
                    sum += sumChild(map, histogram.data);
                }
                this.subGroups = Collections.unmodifiableMap(groups);
            } else {
                // leaf level: snapshot (and optionally reset) the raw counter table
                this.subGroups = null;
                sum += sumCounters(map, counters.data, reset);
            }
            this.data = Collections.unmodifiableMap(map);
            this.instructionsExecuted = sum;
        }

        /**
         * Returns the per-instruction counts for this histogram level when it is a leaf.
         * <p>
         * If this histogram has subgroups, this map may be empty, since counts then live in the
         * children.
         *
         * @return an unmodifiable map from {@link InstructionDescriptor} to count
         * @since 25.1
         */
        public Map<InstructionDescriptor, Long> getCounters() {
            return data;
        }

        /**
         * Returns aggregated statistics across this histogram level.
         * <p>
         * For leaf histograms, this wraps each count in a {@link LongSummaryStatistics}. For parent
         * histograms, statistics are combined across all children.
         *
         * @return a map from {@link InstructionDescriptor} to aggregated statistics
         * @since 25.1
         */
        public Map<InstructionDescriptor, LongSummaryStatistics> getStatistics() {
            if (subGroups == null) {
                return data.entrySet().stream().collect(Collectors.toMap(
                                Map.Entry::getKey,
                                e -> {
                                    LongSummaryStatistics s = new LongSummaryStatistics();
                                    s.accept(e.getValue());
                                    return s;
                                }));
            } else {
                Map<InstructionDescriptor, LongSummaryStatistics> aggregated = new LinkedHashMap<>();
                for (Histogram child : subGroups.values()) {
                    Map<InstructionDescriptor, LongSummaryStatistics> childStats = child.getStatistics();
                    for (var entry : childStats.entrySet()) {
                        aggregated.computeIfAbsent(entry.getKey(), k -> new LongSummaryStatistics()).combine(entry.getValue());
                    }
                }
                return aggregated;
            }
        }

        /**
         * Returns the subgroups of this histogram, keyed by the grouping values produced by the
         * configured {@link GroupClause}s.
         *
         * @return an unmodifiable map of subgroup keys to child histograms, or {@code null} if
         *         there are no groups
         * @since 25.1
         */
        public Map<Object, Histogram> getGroups() {
            return subGroups;
        }

        /**
         * Renders this histogram to a string using {@link #print(PrintStream)}.
         *
         * @return a human readable histogram table as UTF-8 text
         * @since 25.1
         */
        public String dump() {
            /*
             * Encode and decode explicitly as UTF-8 and close the PrintStream before reading the
             * buffer so that internally buffered output is flushed.
             */
            try (ByteArrayOutputStream w = new ByteArrayOutputStream();
                            PrintStream p = new PrintStream(w, true, StandardCharsets.UTF_8)) {
                print(p);
                p.flush();
                return w.toString(StandardCharsets.UTF_8);
            } catch (IOException e) {
                // ByteArrayOutputStream.close() declares but never throws IOException
                throw new AssertionError(e);
            }
        }

        /**
         * Prints this histogram as a single indented table.
         * <p>
         * Numeric columns are aligned. When groups are present, each group is shown as a row that
         * displays the total for the group, followed by the indented rows of its children.
         *
         * @param out the stream to print to
         * @since 25.1
         */
        public void print(PrintStream out) {
            long executed = instructionsExecuted;
            out.println("Instruction histogram for: " + descriptor.getSpecificationClass().getName());
            if (executed == 0) {
                out.println("  (no events)");
                return;
            }

            final String countLabel = "Count";
            final String percentLabel = "Percent";
            final String nameLabel = "Group / Instruction";

            int countWidth = Math.max(countLabel.length(), String.valueOf(executed).length());
            int percentWidth = Math.max(percentLabel.length(), 7); // fits "100.0%"

            String header = String.format("   %" + countWidth + "s  | %" + percentWidth + "s | %s",
                            countLabel, percentLabel, nameLabel);
            String formatString = "   %" + countWidth + "d  | %" + percentWidth + "s | %s%s%n";
            String ruler = "  " + repeat('-', header.length());

            out.println(ruler);
            out.println(header);
            out.println(ruler);

            printRecursive(out, 0, executed, formatString);

            out.println(ruler);
            out.printf("  Total executed instructions: %d%n", executed);
        }

        /**
         * Returns the total number of executed instructions represented by this histogram.
         * <p>
         * For grouped histograms, this is the sum across all children. For leaves, this is the sum
         * of the per-instruction counts.
         *
         * @return the total count represented by this histogram
         * @since 25.1
         */
        public long getInstructionsExecuted() {
            return instructionsExecuted;
        }

        private void printRecursive(PrintStream out, int depth,
                        long globalTotal, String formatString) {
            // Compute total count for this histogram
            long total = this.getInstructionsExecuted();
            if (total == 0) {
                return;
            }

            if (this.subGroups != null) {
                assert !this.subGroups.isEmpty();
                // Sort groups for deterministic output, largest group first
                List<Map.Entry<Object, Histogram>> entries = new ArrayList<>(this.subGroups.entrySet());
                entries.sort(Map.Entry.comparingByValue(Comparator.comparingLong(Histogram::getInstructionsExecuted).reversed()));

                for (var entry : entries) {
                    Histogram child = entry.getValue();
                    long childTotal = child.getInstructionsExecuted();
                    // percentages are always relative to the global total, not the parent group
                    double percentGlobal = (childTotal * 100.0) / globalTotal;
                    String indent = "  ".repeat(depth) + "\u25B6 ";
                    out.printf(formatString, childTotal, String.format("%.1f", percentGlobal), indent, entry.getKey());
                    child.printRecursive(out, depth + 1, globalTotal, formatString);
                }
            } else {
                // Leaf node: print per-instruction rows sorted by descending count
                List<Map.Entry<InstructionDescriptor, Long>> rows = new ArrayList<>(this.data.entrySet());
                rows.sort((a, b) -> Long.compare(b.getValue(), a.getValue()));

                for (var e : rows) {
                    InstructionDescriptor d = e.getKey();
                    long count = e.getValue();
                    double percent = (count * 100.0) / globalTotal;
                    String indent = "  ".repeat(depth);
                    out.printf(formatString, count, String.format("%.1f", percent), indent, formatLabel(d));
                }
            }
        }

        /**
         * Formats an instruction row label as a three digit hexadecimal opcode followed by the
         * instruction name.
         */
        static String formatLabel(InstructionDescriptor instruction) {
            return String.format("%03x %s", instruction.getOperationCode(), instruction.getName());
        }

        private static String repeat(char ch, int n) {
            char[] c = new char[n];
            Arrays.fill(c, ch);
            return new String(c);
        }

        /**
         * Adds the child's per-instruction counts into the parent map and returns the sum of all
         * counts added. Zero counts are skipped.
         */
        private static long sumChild(Map<InstructionDescriptor, Long> parent,
                        Map<InstructionDescriptor, Long> child) {
            long sum = 0;
            for (var entry : child.entrySet()) {
                Long counter = entry.getValue();
                if (counter == 0) {
                    continue;
                }
                sum += counter;
                parent.merge(entry.getKey(), counter, Long::sum);
            }
            return sum;
        }

        /**
         * Snapshots the raw counter table into the given map, optionally resetting each counter
         * atomically, and returns the total of all non-zero counts.
         */
        private long sumCounters(Map<InstructionDescriptor, Long> statistics, AtomicLong[] values, boolean reset) {
            int opcode = -1;
            long sum = 0;
            for (AtomicLong l : values) {
                opcode++;
                // getAndSet makes snapshot-and-reset atomic per counter
                long counter = reset ? l.getAndSet(0L) : l.get();
                if (counter == 0L) {
                    continue;
                }
                InstructionDescriptor d = descriptor.getInstructionDescriptor(opcode);
                assert d != null : "No InstructionDescriptor for opcode=" + opcode + ", counter=" + counter + ", values.length=" + values.length + ", descriptor=" + descriptor;
                if (d == null) {
                    // defensive: skip unmapped opcodes when assertions are disabled
                    continue;
                }
                sum += counter;
                statistics.put(d, counter);
            }
            return sum;
        }

    }

    private static final class Counters {

        final int size;
        final AtomicLong[] data; // non-null only for leaves
        final ConcurrentHashMap<Object, Counters> subgroups; // non-null only for parents

        Counters(int tableSize, boolean leaf) {
            this.size = tableSize;
            if (leaf) {
                this.data = createCounters(tableSize);
                this.subgroups = null;
            } else {
                this.data = null;
                this.subgroups = new ConcurrentHashMap<>();
            }
        }

        private static AtomicLong[] createCounters(int size) {
            /*
             * We are using AtomicLong as counters instead of VarHandle + long[] in order to make
             * sure the counter increment can be safely partially evaluated and does not cause any
             * additional overhead in the interpreter.
             */
            AtomicLong[] c = new AtomicLong[size];
            for (int i = 0; i < c.length; i++) {
                c[i] = new AtomicLong();
            }
            return c;
        }

        public Counters getOrCreateGroup(Object childGroup, boolean leaf) {
            Objects.requireNonNull(childGroup, "Group key must not be null.");
            return subgroups.computeIfAbsent(childGroup, (k) -> new Counters(size, leaf));
        }

        public void reset() {
            if (data != null) {
                resetCounters(data);
            }
            if (subgroups != null) {
                for (Counters child : subgroups.values()) {
                    child.reset();
                }
            }
        }

        private static void resetCounters(AtomicLong[] counters) {
            for (AtomicLong counter : counters) {
                counter.getAndSet(0);
            }
        }
    }

    /**
     * Immutable cache entry recording the last resolved leaf counter table, filter decision,
     * compilation tier, and thread for a bytecode node.
     */
    private record LastTraceCache(BytecodeNode bytecodeNode, AtomicLong[] counters, boolean included, int compiledTier, Thread thread) {
    }

}
